{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "16de7336",
   "metadata": {},
   "source": [
    "# OpenAI Function Calling In LangChain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb41f5f4-df8d-4d04-9eaa-193b8c29b00b",
   "metadata": {
    "height": 115,
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import openai\n",
    "\n",
    "from dotenv import load_dotenv, find_dotenv\n",
    "_ = load_dotenv(find_dotenv()) # read local .env file\n",
    "openai.api_key = os.environ['OPENAI_API_KEY']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa1dddf9-8e44-4454-9d44-f8372cccf5ac",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "from typing import List\n",
    "from pydantic import BaseModel, Field"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad68931a-f806-4ea9-969c-93b3902baf9b",
   "metadata": {},
   "source": [
    "## Pydantic Syntax\n",
    "\n",
    "Pydantic models combine the concise, declarative style of Python's data classes with Pydantic's runtime validation.\n",
    "\n",
    "They offer a concise way to define data structures while ensuring that the data adheres to specified types and constraints.\n",
    "\n",
    "In standard Python you would create a class like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1557226-36e2-484b-a2fb-bb7e3180342c",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "class User:\n",
    "    \"\"\"Plain Python record for a user; the annotations are hints only and are not checked.\"\"\"\n",
    "\n",
    "    def __init__(self, name: str, age: int, email: str):\n",
    "        # Store the arguments exactly as given, in declaration order.\n",
    "        self.name, self.age, self.email = name, age, email"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11b9b584-74dc-49b8-a7fe-3865368774e9",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "foo = User(name=\"Joe\",age=32, email=\"joe@gmail.com\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f6f9e9c-b83a-4859-8e65-e6488e05a071",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "foo.name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10a6a0de-d7dc-414d-baaf-fa43c6d1f410",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "foo = User(name=\"Joe\",age=\"bar\", email=\"joe@gmail.com\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "613b7b12-f061-44bc-989d-433cab609164",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "foo.age"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c541cb8d-fc55-4c94-a04f-a877cccf10ec",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "# Pydantic counterpart of User: the same three fields, declared as annotated class attributes.\n",
    "class pUser(BaseModel):\n",
    "    name: str\n",
    "    age: int\n",
    "    email: str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27394d22-73e3-4918-9bdf-18cd7c973942",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "foo_p = pUser(name=\"Jane\", age=32, email=\"jane@gmail.com\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49f25241-ff47-454f-bac4-ba20ab937d70",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "foo_p.name"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77405fab-9029-47e3-a6d3-f8338c2fefdc",
   "metadata": {},
   "source": [
    "**Note**: Creating the model in the next cell is expected to fail: `age=\"bar\"` is not a valid integer, so pydantic raises a validation error."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37030df3-ec11-4523-ac66-b88f90099d1b",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from pydantic import ValidationError\n",
    "\n",
    "# age=\"bar\" cannot be parsed as an int, so pydantic rejects it. Catch the error and\n",
    "# show it so that Restart & Run All can continue past this cell. On failure foo_p\n",
    "# keeps the value assigned by the earlier, valid cell.\n",
    "try:\n",
    "    foo_p = pUser(name=\"Jane\", age=\"bar\", email=\"jane@gmail.com\")\n",
    "except ValidationError as err:\n",
    "    print(err)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "911e7677-cc5d-4957-b6c3-b3ba1493de33",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Models can nest: `students` is a list of pUser entries.\n",
    "class Class(BaseModel):\n",
    "    students: List[pUser]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14920e50-688e-4dd4-9207-9b59c6b018c6",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "obj = Class(\n",
    "    students=[pUser(name=\"Jane\", age=32, email=\"jane@gmail.com\")]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cede1035-7581-4203-bab7-b8e6363c931f",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "obj"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9c12cef-3a2d-46da-9c45-9a117e10f4a4",
   "metadata": {},
   "source": [
    "## Pydantic to OpenAI function definition\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "617ceea9-009f-4325-adae-85ab29fccd68",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "# Schema for a weather-lookup function with one string argument, `airport_code`.\n",
    "# NOTE(review): the class docstring and the Field description are presumably what the\n",
    "# converter uses as the function and parameter descriptions -- confirm against the\n",
    "# `weather_function` output below.\n",
    "class WeatherSearch(BaseModel):\n",
    "    \"\"\"Call this with an airport code to get the weather at that airport\"\"\"\n",
    "    airport_code: str = Field(description=\"airport code to get weather for\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b22a438c-6692-47f9-9e00-1a95d04c6dd3",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "from langchain.utils.openai_functions import convert_pydantic_to_openai_function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97e152c4-8d04-4a02-b363-ee7691f60e31",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "weather_function = convert_pydantic_to_openai_function(WeatherSearch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f36b041e-bd28-4e25-a0c1-fbeeeee4ae53",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "weather_function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5d7c573-2f84-441d-ab73-a3dd263318d4",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# Same field as WeatherSearch, but with no class docstring.\n",
    "class WeatherSearch1(BaseModel):\n",
    "    airport_code: str = Field(description=\"airport code to get weather for\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d99b688-a9a7-4446-977f-07918a5d93e1",
   "metadata": {},
   "source": [
    "**Note**: The next cell is expected to generate an error: `WeatherSearch1` has no docstring, so there is no description to use for the function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb08f095-8190-41c5-b49c-e3580cedf992",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "convert_pydantic_to_openai_function(WeatherSearch1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed22668a-e188-45a5-844e-deee62f9bf51",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "# Keeps the class docstring of WeatherSearch but declares airport_code without a Field description.\n",
    "class WeatherSearch2(BaseModel):\n",
    "    \"\"\"Call this with an airport code to get the weather at that airport\"\"\"\n",
    "    airport_code: str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e001e87-4338-4720-99b3-9dc4cb3e4faf",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "convert_pydantic_to_openai_function(WeatherSearch2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a35261e8-2d36-43a8-a051-a79bef35c8dd",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from langchain.chat_models import ChatOpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93bda0a6-9206-407a-b0da-966f9442a40c",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "model = ChatOpenAI()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afe342d4-a7ef-49cd-b760-aa9a176d64d5",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "model.invoke(\"what is the weather in SF today?\", functions=[weather_function])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "511e12b6-bcfb-4862-b377-4251de9969ea",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "model_with_function = model.bind(functions=[weather_function])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de6241d9-667c-4b97-a50f-95c046fa640c",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "model_with_function.invoke(\"what is the weather in sf?\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae78d6dd-bb38-4e55-9b65-0ef9005a52b9",
   "metadata": {},
   "source": [
    "## Forcing it to use a function\n",
    "\n",
    "We can force the model to call a specific function by passing its name as `function_call`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7d2fbc0-df39-4e93-a22f-39a6285272b4",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "model_with_forced_function = model.bind(functions=[weather_function], function_call={\"name\":\"WeatherSearch\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd9f4063-9e15-41d7-9cf9-253548534176",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "model_with_forced_function.invoke(\"what is the weather in sf?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "314ca7e6-b77c-4b9d-9c93-da6ef3c9c6f8",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "model_with_forced_function.invoke(\"hi!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ac391c3-cd81-4423-a33e-6583ec534850",
   "metadata": {},
   "source": [
    "## Using in a chain\n",
    "\n",
    "We can use this model bound to a function in a chain as we normally would."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c12c86df-d628-4176-9f4e-24fb5a953a5d",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "from langchain.prompts import ChatPromptTemplate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83f00dc6-5d22-44a5-a0a0-4fe1ab8167ac",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "prompt = ChatPromptTemplate.from_messages([\n",
    "    (\"system\", \"You are a helpful assistant\"),\n",
    "    (\"user\", \"{input}\")\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1eee3f5f-2176-4777-8c8a-2a197acc47a7",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "chain = prompt | model_with_function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8587907e-b4c3-4acd-9e58-1137047d0fee",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "chain.invoke({\"input\": \"what is the weather in sf?\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f317408d-de5e-4774-993e-a8ac31a2f5fe",
   "metadata": {},
   "source": [
    "## Using multiple functions\n",
    "\n",
    "Even better, we can pass a set of functions and let the LLM decide which to use based on the question context."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48c8e42e-f84e-4822-b1ee-a9955fa301c4",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "# Second function schema: takes the artist's name and the number of results to return.\n",
    "class ArtistSearch(BaseModel):\n",
    "    \"\"\"Call this to get the names of songs by a particular artist\"\"\"\n",
    "    artist_name: str = Field(description=\"name of artist to look up\")\n",
    "    n: int = Field(description=\"number of results\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a51599b-ee32-4f74-9925-b6264bf43242",
   "metadata": {
    "height": 80
   },
   "outputs": [],
   "source": [
    "# Convert each pydantic schema, in this order, into an OpenAI function definition.\n",
    "functions = [\n",
    "    convert_pydantic_to_openai_function(schema)\n",
    "    for schema in (WeatherSearch, ArtistSearch)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0cc9e3b-ba38-4eff-b285-d02ee5963725",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "model_with_functions = model.bind(functions=functions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d01cd501-03b4-4207-b1be-0c33c12a0fa5",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "model_with_functions.invoke(\"what is the weather in sf?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31599518-7387-4d08-9d68-8f5ba7282e8b",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "model_with_functions.invoke(\"what are three songs by taylor swift?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e6df989-6ea3-48af-b2c0-10978e0f4142",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "model_with_functions.invoke(\"hi!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b29653b7-1382-4375-b681-10c51964fff5",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b463f4b3-1105-4c64-9a23-32dabe86e71b",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7651222c-ec2c-4e7e-a397-b8f962906037",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91bc3f83-1631-4f34-b852-e7325b8f39c2",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfd13a31-531d-493b-b707-9525915d2c02",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "592876b9-925b-4ab9-ada8-bb9af0a4b523",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a0e1e52-101d-4f18-b3cb-f45ef332ccbe",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d22ea80-98f6-4599-a48a-0a41284994d0",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c66221b-e1a9-4ea2-833b-5ab6184a8454",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0c1ab37-f001-42ce-b18f-b40dfb005f43",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7eb94a2-ad8f-41da-bbff-a627924f7d34",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f380d482-b37f-4806-ab36-09a8ee60d258",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
